#!/usr/bin/env python
# vim: set filetype=python sw=2 sts=2 et nofoldenable :
# This is a python script.
# If you encounter problems when executing it on its own, start it with a Python interpreter
import sys
import subprocess
import os
import shutil
import copy

def median(lst):
    """Return the median of the numbers in ``lst``.

    The argument is not modified; a sorted copy is used.

    Returns None for an empty input, the middle element for an odd number
    of values, and the mean of the two middle elements (as a float) for an
    even number of values.
    """
    values = sorted(lst)
    count = len(values)
    if count < 1:
        return None
    # Floor division keeps the index an int on Python 3 as well, where "/"
    # yields a float and would make the subscript raise TypeError.
    mid = count // 2
    if count % 2 == 1:
        return values[mid]
    return float(values[mid - 1] + values[mid]) / 2.0

def get_solvers(file_path, file_base):
  """Return the solver names found in the rank 0 file ``<file_base>.0``.

  A solver name is any line longer than two characters that starts with
  "::"; trailing whitespace (including the newline) is stripped. The names
  are returned in file order.
  """
  rank0_name = file_path + os.sep + file_base + ".0"
  with open(rank0_name, "r") as stat_file:
    return [entry.rstrip() for entry in stat_file
            if len(entry) > 2 and entry.startswith("::")]

def info(file_path, file_base, ranks):
  """Print the number of ranks and the solver names, one per line.

  The solver names are read from the rank 0 file via get_solvers().
  """
  print("ranks: " + str(ranks))
  print("solvers:")
  # The two heading lines are printed before the file is read, so they
  # appear even if reading the rank 0 file fails.
  for name in get_solvers(file_path, file_base):
    print(name)

def _rank_file(file_path, file_base, rank):
  """Return the path of the statistics file ``<file_base>.<rank>``."""
  return file_path + os.sep + file_base + "." + str(rank)


def _dat_file(file_path, prefix, solver):
  """Return the output path ``<prefix><solver>.dat`` with "::" mapped to "_"."""
  return file_path + os.sep + prefix + solver.replace("::", "_") + ".dat"


def _seek_section(f, solver, marker):
  """Consume ``f`` up to and including the ``marker`` line of ``solver``.

  Returns True once a line equal to ``solver`` (ignoring trailing
  whitespace) has been followed by a line containing ``marker``. Returns
  False, with ``f`` exhausted, if that never happens.
  """
  found_solver = False
  for line in f:
    if line.rstrip() == solver:
      found_solver = True
    elif found_solver and marker in line:
      return True
  return False


def _skip_to(f, marker, file_name):
  """Consume ``f`` up to and including the first line containing ``marker``.

  Raises ValueError at end of file. A bare ``while marker not in
  f.readline()`` loop would spin forever there, because readline() keeps
  returning "" once the file is exhausted.
  """
  while True:
    line = f.readline()
    if not line:
      raise ValueError("'" + marker + "' not found in " + file_name)
    if marker in line:
      return


def _read_values(f):
  """Read floats from ``f`` until a line containing "#" or "::" (consumed).

  Lines starting with "-" are separators and are skipped.
  """
  values = []
  for line in f:
    if line[0] == "-":
      continue
    if "#" in line or "::" in line:
      break
    values.append(float(line))
  return values


def _mean_min_max(values):
  """Format ``values`` as "<mean> <min> <max>"."""
  return str(sum(values) / float(len(values))) + " " + str(min(values)) + " " + str(max(values))


def _median_min_max(values):
  """Format ``values`` as "<median> <min> <max>"."""
  return str(median(values)) + " " + str(min(values)) + " " + str(max(values))


def _write_defect_iters(file_path, file_base, ranks, solver):
  """Write defect-iters<solver>.dat: per rank, mean/min/max defect per iteration."""
  with open(_dat_file(file_path, "defect-iters", solver), "w") as out:
    for rank in range(ranks):
      with open(_rank_file(file_path, file_base, rank), "r") as f:
        out.write("\"rank " + str(rank) + "\"\n")
        defects_per_step = {}
        step = 0
        _seek_section(f, solver, "#defects")
        for line in f:
          if line[0] == "-":
            # separator line: the iteration counter starts over
            step = 0
            continue
          if "#" in line:
            break
          defects_per_step.setdefault(step, []).append(float(line))
          step += 1

        for i in range(len(defects_per_step)):
          out.write(str(i) + " " + _mean_min_max(defects_per_step[i]) + "\n")

        if rank != ranks - 1:
          # gnuplot dataset separator (two blank lines)
          out.write("\n\n")


def _write_time_ranks(file_path, file_base, ranks, solver):
  """Write time-ranks<solver>.dat: per rank, median/min/max of both timing lists."""
  with open(_dat_file(file_path, "time-ranks", solver), "w") as out:
    out.write("rank a complete b c foundation d\n")
    for rank in range(ranks):
      with open(_rank_file(file_path, file_base, rank), "r") as f:
        _seek_section(f, solver, "#toe")
        timings = _read_values(f)
        # The mpi_execute values directly follow the section read above.
        # TODO add mpi_wait
        mpi_timings = _read_values(f)
        out.write(str(rank) + " " + _median_min_max(timings) + " " + _median_min_max(mpi_timings) + "\n")


def _write_rank_table(file_path, file_base, ranks, marker, count, to_mib, header, out_name):
  """Write a table with one row per rank plus a leading "average" row.

  In every rank file the ``count`` lines after the first ``marker`` line
  are read; the second space-separated field of each is the value. With
  ``to_mib`` set the value is divided by 1024 twice (presumably bytes to
  MiB -- confirm against the producer of the statistics files).
  """
  totals = [0.] * count
  rows = []
  for rank in range(ranks):
    rank_name = _rank_file(file_path, file_base, rank)
    with open(rank_name, "r") as f:
      _skip_to(f, marker, rank_name)
      row = str(rank)
      for column in range(count):
        value = float(f.readline().rsplit(" ")[1])
        if to_mib:
          value = value / 1024. / 1024.
        totals[column] += value
        row += " " + str(value)
      rows.append(row + "\n")

  averages = [total / float(ranks) for total in totals]
  rows.insert(0, "average " + " ".join([str(avg) for avg in averages]) + "\n")
  rows.insert(0, header + "\n")
  with open(file_path + os.sep + out_name, "w") as out:
    out.writelines(rows)


def _write_iters_ranks(file_path, file_base, ranks, solver):
  """Write iters-ranks<solver>.dat: per rank, mean/min/max entries per "#toe" run."""
  with open(_dat_file(file_path, "iters-ranks", solver), "w") as out:
    out.write("rank a iters b\n")
    for rank in range(ranks):
      with open(_rank_file(file_path, file_base, rank), "r") as f:
        iters = []
        count = 0
        _seek_section(f, solver, "#toe")
        for line in f:
          if line[0] == "-":
            # separator line closes the current run
            iters.append(count)
            count = 0
            continue
          if "#" in line or "::" in line:
            iters.append(count)
            break
          count += 1

        out.write(str(rank) + " " + _mean_min_max(iters) + "\n")


def _write_time_levels(file_path, file_base, solver):
  """Write time-levels<solver>.dat: per level, median/min/max timings of rank 0."""
  with open(_dat_file(file_path, "time-levels", solver), "w") as out:
    out.write("level a complete b c foundation d\n")

    # Only rank 0 is used as a first shot (needs further intelligence later
    # on for choosing ranks with min/max values).
    rank_name = _rank_file(file_path, file_base, 0)
    with open(rank_name, "r") as f:
      timings_per_level = {}
      mpi_timings_per_level = {}
      _skip_to(f, solver, rank_name)
      f.readline() # #defects
      f.readline() # #toe
      line = f.readline()
      level = 0
      while "#mpi_execute" not in line:
        if not line:
          raise ValueError("'#mpi_execute' not found after " + solver + " in " + rank_name)
        if line[0] == "-":
          # separator line: the level counter starts over
          level = 0
        else:
          timings_per_level.setdefault(level, []).append(float(line))
          level += 1
        line = f.readline()

      line = f.readline()
      level = 0
      while line and "::" not in line and "#" not in line:
        if line[0] == "-":
          level = 0
        else:
          mpi_timings_per_level.setdefault(level, []).append(float(line))
          level += 1
        line = f.readline()

      #TODO add mpi_wait

      for level in range(len(timings_per_level)):
        out.write(str(level) + " " + _median_min_max(timings_per_level[level]) + " "
            + _median_min_max(mpi_timings_per_level[level]) + "\n")


def dat(file_path, file_base, ranks):
  """Create all gnuplot .dat files from the per-rank statistics files.

  file_path -- directory holding the input files; output is written there too
  file_base -- base name; rank N is read from ``<file_base>.N``
  ranks     -- number of rank files (0 .. ranks-1)

  Files written:
    defect-iters<solver>.dat, time-ranks<solver>.dat, iters-ranks<solver>.dat
      for every solver whose name does not end in "::VCycle",
    time-levels<solver>.dat for every solver whose name ends in "::VCycle",
    global-time-ranks.dat and global-size-ranks.dat once per call.

  Raises ValueError when a "#solver time" / "#global size" marker or a
  VCycle solver section is missing from a rank file.

  Lines are terminated with "\\n": the files are opened in text mode, which
  already translates to the platform line ending.
  """
  solvers = get_solvers(file_path, file_base)
  plain_solvers = [s for s in solvers if not s.endswith("::VCycle")]
  vcycle_solvers = [s for s in solvers if s.endswith("::VCycle")]

  # defect / iters
  for solver in plain_solvers:
    _write_defect_iters(file_path, file_base, ranks, solver)

  # time / ranks
  for solver in plain_solvers:
    _write_time_ranks(file_path, file_base, ranks, solver)

  # solver time / ranks
  # The two global tables do not depend on a solver, so they are written
  # exactly once instead of once per solver.
  _write_rank_table(file_path, file_base, ranks, "#solver time", 5, False,
      "rank reduction axpy spmv precon foundation", "global-time-ranks.dat")

  # global size / ranks
  _write_rank_table(file_path, file_base, ranks, "#global size", 3, True,
      "rank domain mpi la", "global-size-ranks.dat")

  # iters / ranks
  for solver in plain_solvers:
    _write_iters_ranks(file_path, file_base, ranks, solver)

  # time / levels
  for solver in vcycle_solvers:
    _write_time_levels(file_path, file_base, solver)


def _print_usage():
  """Print the help screen."""
  print ("Usage: create_dat [help] / [filename command]")
  print ("")
  print ("Commands:")
  print ("help")
  print ("info")
  print ("  print a list of available solvers, ranks, etc")
  print ("dat")
  print ("  create all gnuplot dat files")


def main(argv):
  """Run the command line interface.

  argv -- full argument vector: argv[0] is the script name, argv[1] the
          statistics file (or its base name), argv[2] the command.

  Exits through sys.exit() after printing the help screen or an error.
  """
  args = argv[1:]

  # Only real arguments are inspected for a help request; the script path
  # in argv[0] may legitimately contain "help" or "?".
  if not args or any(arg.lstrip("-") in ("help", "?") for arg in args):
    _print_usage()
    sys.exit()

  if len(args) < 2:
    print("missing command for " + args[0])
    _print_usage()
    sys.exit(1)

  filebase = args[0]
  command = args[1]

  path_head, path_tail = os.path.split(filebase)
  if path_head == "":
    path_head = "."

  # Accept either the common base name of the rank files or the name of
  # one existing rank file (whose last extension is then dropped).
  listing = " ".join(os.listdir(path_head))
  has_rank_files = (path_tail + ".") in listing
  is_existing_file = os.path.isfile(filebase)
  if has_rank_files and not is_existing_file:
    pass
  elif not has_rank_files and is_existing_file and "." in path_tail:
    path_tail = path_tail.rsplit(".", 1)[0]
  else:
    print("confusion: " + filebase + " not filename filebase nor existing stat file")
    sys.exit()

  # Rank files are numbered consecutively from 0; count until the first gap.
  ranks = 0
  while os.path.isfile(path_head + os.sep + path_tail + "." + str(ranks)):
    ranks += 1

  if command == "info":
    info(path_head, path_tail, ranks)
  elif command == "dat":
    dat(path_head, path_tail, ranks)
  else:
    print("command " + command + " not known!")


if __name__ == "__main__":
  main(sys.argv)
